"""Tool execution node for LangGraph workflows.

This module provides prebuilt functionality for executing tools in LangGraph.

Tools are functions that models can call to interact with external systems,
APIs, databases, or perform computations.

The module implements design patterns for:

- Parallel execution of multiple tool calls for efficiency
- Robust error handling with customizable error messages
- State injection for tools that need access to graph state
- Store injection for tools that need persistent storage
- Command-based state updates for advanced control flow

Key Components:

- `ToolNode`: Main class for executing tools in LangGraph workflows
- `InjectedState`: Annotation for injecting graph state into tools
- `InjectedStore`: Annotation for injecting persistent store into tools
- `ToolRuntime`: Runtime information for tools, bundling together `state`, `context`,
    `config`, `stream_writer`, `tool_call_id`, and `store`
- `tools_condition`: Utility function for conditional routing based on tool calls

Typical Usage:
    ```python
    from langchain_core.tools import tool
    from langchain.tools import ToolNode


    @tool
    def my_tool(x: int) -> str:
        return f"Result: {x}"


    tool_node = ToolNode([my_tool])
    ```
"""

from __future__ import annotations

import asyncio
import inspect
import json
from collections.abc import Awaitable, Callable
from copy import copy, deepcopy
from dataclasses import dataclass, replace
from types import UnionType
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Generic,
    Literal,
    TypedDict,
    Union,
    cast,
    get_args,
    get_origin,
    get_type_hints,
)

from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    RemoveMessage,
    ToolCall,
    ToolMessage,
    convert_to_messages,
)
from langchain_core.runnables.config import (
    RunnableConfig,
    get_config_list,
    get_executor_for_config,
)
from langchain_core.tools import BaseTool, InjectedToolArg
from langchain_core.tools import tool as create_tool
from langchain_core.tools.base import (
    TOOL_MESSAGE_BLOCK_TYPES,
    ToolException,
    _DirectlyInjectedToolArg,
    get_all_basemodel_annotations,
)
from langgraph._internal._runnable import RunnableCallable
from langgraph.errors import GraphBubbleUp
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.store.base import BaseStore  # noqa: TC002
from langgraph.types import Command, Send, StreamWriter
from pydantic import BaseModel, ValidationError
from typing_extensions import TypeVar, Unpack

if TYPE_CHECKING:
    from collections.abc import Sequence

    from langgraph.runtime import Runtime
    from pydantic_core import ErrorDetails

# State type parameter.
# right now we use a dict as the default, can change this to AgentState, but depends
# on if this lives in LangChain or LangGraph... ideally would have some typed
# messages key
StateT = TypeVar("StateT", default=dict)
# Context type parameter; defaults to None.
ContextT = TypeVar("ContextT", default=None)

# Error-message templates. The two that are used in this part of the file are
# filled in with `str.format`; the others are presumably used the same way.

# Placeholders: `requested_tool` and `available_tools`.
INVALID_TOOL_NAME_ERROR_TEMPLATE = (
    "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
)
# Placeholder: `error`. Used by `_handle_tool_error` for the default message,
# where `error` is the `repr()` of the caught exception.
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
# Placeholders: `tool_name`, `tool_kwargs`, `error`.
TOOL_EXECUTION_ERROR_TEMPLATE = (
    "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
    " {error}\n"
    " Please fix the error and try again."
)
# Placeholders: `tool_name`, `tool_kwargs`, `error`. Used by
# `ToolInvocationError` to build its message.
TOOL_INVOCATION_ERROR_TEMPLATE = (
    "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
    " {error}\n"
    " Please fix the error and try again."
)


class _ToolCallRequestOverrides(TypedDict, total=False):
    """Keyword arguments accepted by the `ToolCallRequest.override()` method.

    Declared with `total=False`, so every key is optional.
    """

    # Replacement tool call dict for the new request.
    tool_call: ToolCall


@dataclass
class ToolCallRequest:
    """A single tool call, packaged for tool call interceptors.

    Attributes:
        tool_call: The tool call dict (name, args, id) taken from model output.
        tool: The `BaseTool` to invoke, or `None` when the requested tool is not
            registered with the `ToolNode`. An interceptor may handle a request
            whose tool is `None` itself; if it calls `execute()` instead,
            validation runs and reports the unregistered tool.
        state: Agent state (`dict`, `list`, or `BaseModel`).
        runtime: The `ToolRuntime` for this call.
    """

    tool_call: ToolCall
    tool: BaseTool | None
    state: Any
    runtime: ToolRuntime

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign an attribute, warning when an existing one is overwritten.

        The first assignment of each attribute (which is what happens during
        initialization) is silent. Re-assigning an attribute that is already
        present emits a `DeprecationWarning`; the assignment is still performed.
        Prefer `override()`, which builds a new instance instead.
        """
        is_reassignment = hasattr(self, "__dataclass_fields__") and hasattr(
            self, name
        )
        if is_reassignment:
            import warnings

            warnings.warn(
                f"Setting attribute '{name}' on ToolCallRequest is deprecated. "
                "Use the override() method instead to create a new instance with modified values.",
                DeprecationWarning,
                stacklevel=2,
            )
        # The value is stored in both cases; only the warning is conditional.
        object.__setattr__(self, name, value)

    def override(
        self, **overrides: Unpack[_ToolCallRequestOverrides]
    ) -> ToolCallRequest:
        """Build a copy of this request with some attributes replaced.

        The current request is left untouched; a new `ToolCallRequest` is created
        via `dataclasses.replace`.

        Args:
            **overrides: Attributes to replace on the copy. The typed keyword set
                currently declares:
                - tool_call: Tool call dict with name, args, and id

        Returns:
            A new `ToolCallRequest` carrying the overridden values.

        Examples:
            ```python
            # Change the tool call arguments without mutating the original
            modified_call = {**request.tool_call, "args": {"value": 10}}
            new_request = request.override(tool_call=modified_call)
            ```
        """
        return replace(self, **overrides)


# Sync interceptor signature: (request, execute) -> ToolMessage | Command, where
# `execute` itself takes a ToolCallRequest and returns ToolMessage | Command.
ToolCallWrapper = Callable[
    [ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
    ToolMessage | Command,
]
"""Wrapper for tool call execution with multi-call support.

Wrapper receives:
    request: ToolCallRequest with tool_call, tool, state, and runtime.
    execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES).

Returns:
    ToolMessage or Command (the final result).

The execute callable can be invoked multiple times for retry logic,
with potentially modified requests each time. Each call to execute
is independent and stateless.

!!! note
    When implementing middleware for `create_agent`, use
    `AgentMiddleware.wrap_tool_call` which provides properly typed
    state parameter for better type safety.

Examples:
    Passthrough (execute once):

    ```python
    def handler(request, execute):
        return execute(request)
    ```

    Modify request before execution:

    ```python
    def handler(request, execute):
        modified_call = {**request.tool_call, "args": {**request.tool_call["args"], "value": request.tool_call["args"]["value"] * 2}}
        modified_request = request.override(tool_call=modified_call)
        return execute(modified_request)
    ```

    Retry on error (execute multiple times):

    ```python
    def handler(request, execute):
        for attempt in range(3):
            try:
                result = execute(request)
                if is_valid(result):
                    return result
            except Exception:
                if attempt == 2:
                    raise
        return result
    ```

    Conditional retry based on response:

    ```python
    def handler(request, execute):
        for attempt in range(3):
            result = execute(request)
            if isinstance(result, ToolMessage) and result.status != "error":
                return result
            if attempt < 2:
                continue
            return result
    ```

    Cache/short-circuit without calling execute:

    ```python
    def handler(request, execute):
        if cached := get_cache(request):
            return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
        result = execute(request)
        save_cache(request, result)
        return result
    ```
"""

# Async counterpart of ToolCallWrapper: both the `execute` callable and the
# wrapper itself return an awaitable resolving to ToolMessage | Command.
AsyncToolCallWrapper = Callable[
    [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
    Awaitable[ToolMessage | Command],
]
"""Async wrapper for tool call execution with multi-call support."""


class ToolCallWithContext(TypedDict):
    """ToolCall with additional context for graph state.

    This is an internal data structure meant to help the `ToolNode` accept
    tool calls with additional context (e.g. state) when dispatched using the
    Send API.

    The Send API is used in create_agent to distribute tool calls in parallel
    and support human-in-the-loop workflows where graph execution may be paused
    for an indefinite time.
    """

    # The tool call being dispatched.
    tool_call: ToolCall
    # NOTE(review): a `__`-prefixed name declared inside a class body is subject
    # to Python's private name mangling, so at runtime this annotation may be
    # recorded as `_ToolCallWithContext__type` rather than `__type` — confirm
    # that consumers only rely on the literal "__type" dict key.
    __type: Literal["tool_call_with_context"]
    """Type to parameterize the payload.

    Using "__" as a prefix to be defensive against potential name collisions with
    regular user state.
    """
    state: Any
    """The state is provided as additional context."""


def msg_content_output(output: Any) -> str | list[dict]:
    """Turn a tool's return value into `ToolMessage` content.

    Strings pass through unchanged, and so does a list whose items are all dicts
    with a `type` found in `TOOL_MESSAGE_BLOCK_TYPES` (an empty list qualifies).
    Anything else is JSON-encoded, falling back to `str()` when encoding fails.

    Args:
        output: Tool execution output of any type.

    Returns:
        String or list of content blocks suitable for `ToolMessage.content`.
    """
    if isinstance(output, str):
        return output

    if isinstance(output, list):
        only_content_blocks = True
        for item in output:
            if not isinstance(item, dict):
                only_content_blocks = False
                break
            if item.get("type") not in TOOL_MESSAGE_BLOCK_TYPES:
                only_content_blocks = False
                break
        if only_content_blocks:
            return output

    # A list of plain strings would technically be valid message content too,
    # but chat-model support for it is not well tested, and existing ToolNode
    # users must not be broken, so such lists are serialized like any other value.
    try:
        serialized = json.dumps(output, ensure_ascii=False)
    except Exception:  # noqa: BLE001
        # Best effort: whatever cannot be JSON-encoded is stringified.
        return str(output)
    return serialized


class ToolInvocationError(ToolException):
    """Raised when a tool cannot be invoked because its arguments are invalid.

    This exception is only raised when invoking a tool using the `ToolNode`!
    """

    def __init__(
        self,
        tool_name: str,
        source: ValidationError,
        tool_kwargs: dict[str, Any],
        filtered_errors: list[ErrorDetails] | None = None,
    ) -> None:
        """Build the error and its formatted message.

        Args:
            tool_name: The name of the tool that failed.
            source: The validation error that triggered this exception.
            tool_kwargs: The keyword arguments that were passed to the tool.
            filtered_errors: Optional list of validation errors with injected
                arguments already removed. When given (even if empty), the
                message is built from these entries instead of `str(source)`.
        """
        if filtered_errors is None:
            error_display_str = str(source)
        else:
            # One "<dotted.location>: <message>" line per error, without the
            # URLs and extra formatting of the full validation error text.
            rendered_lines = []
            for detail in filtered_errors:
                location = ".".join(str(part) for part in detail.get("loc", ()))
                reason = detail.get("msg", "Unknown error")
                rendered_lines.append(f"{location}: {reason}")
            error_display_str = "\n".join(rendered_lines)

        self.tool_name = tool_name
        self.tool_kwargs = tool_kwargs
        self.source = source
        self.filtered_errors = filtered_errors
        self.message = TOOL_INVOCATION_ERROR_TEMPLATE.format(
            tool_name=tool_name, tool_kwargs=tool_kwargs, error=error_display_str
        )
        super().__init__(self.message)


def _default_handle_tool_errors(e: Exception) -> str:
    """Default handler: report invocation errors, re-raise everything else.

    A `ToolInvocationError` is turned into its message string. Any other
    exception is raised again unchanged.
    """
    if not isinstance(e, ToolInvocationError):
        raise e
    return e.message


def _handle_tool_error(
    e: Exception,
    *,
    flag: bool
    | str
    | Callable[..., str]
    | type[Exception]
    | tuple[type[Exception], ...],
) -> str:
    """Produce the error message content for a failed tool call.

    The message depends on `flag`, which mirrors the `ToolNode`'s
    `handle_tool_errors` setting.

    Args:
        e: The exception that occurred during tool execution.
        flag: How the error should be rendered:
            - bool, tuple, or an Exception subclass: use the default error
              template, filled with `repr(e)`
            - str: use this string as the message
            - Callable: call it with the exception and use its result

    Returns:
        A string containing the error message to include in the `ToolMessage`.

    Raises:
        ValueError: If flag is not one of the supported types.

    !!! note
        Deciding whether an exception matches a tuple (or a single exception
        type) is left to the caller; this function only formats the message.
    """
    # Exception classes are themselves callable, so this check has to come
    # before the generic callable branch below.
    wants_default_template = isinstance(flag, (bool, tuple)) or (
        isinstance(flag, type) and issubclass(flag, Exception)
    )
    if wants_default_template:
        return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))

    if isinstance(flag, str):
        return flag

    if callable(flag):
        return flag(e)  # type: ignore[call-arg]

    msg = (
        f"Got unexpected type of `handle_tool_error`. Expected bool, str "
        f"or callable. Received: {flag}"
    )
    raise ValueError(msg)


def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
    """Infer exception types handled by a custom error handler function.

    This function analyzes the type annotations of a custom error handler to determine
    which exception types it's designed to handle. This enables type-safe error handling
    where only specific exceptions are caught and processed by the handler.

    Args:
        handler: A callable that takes an exception and returns an error message string.
                The first parameter (after self/cls if present) should be type-annotated
                with the exception type(s) to handle.

    Returns:
        A tuple of exception types that the handler can process. Returns (Exception,)
        if no specific type information is available for backward compatibility.

    Raises:
        ValueError: If the handler's annotation contains non-Exception types or
            if Union types contain non-Exception types.

    !!! note
        This function supports both single exception types and Union types for
        handlers that need to handle multiple exception types differently.
    """
    sig = inspect.signature(handler)
    params = list(sig.parameters.values())
    if params:
        # If it's a method, the first argument is typically 'self' or 'cls'
        if params[0].name in ["self", "cls"] and len(params) == 2:
            first_param = params[1]
        else:
            first_param = params[0]

        type_hints = get_type_hints(handler)
        if first_param.name in type_hints:
            origin = get_origin(first_param.annotation)
            if origin in [Union, UnionType]:
                args = get_args(first_param.annotation)
                if all(issubclass(arg, Exception) for arg in args):
                    return tuple(args)
                msg = (
                    "All types in the error handler error annotation must be "
                    "Exception types. For example, "
                    "`def custom_handler(e: Union[ValueError, TypeError])`. "
                    f"Got '{first_param.annotation}' instead."
                )
                raise ValueError(msg)

            exception_type = type_hints[first_param.name]
            if Exception in exception_type.__mro__:
                return (exception_type,)
            msg = (
                f"Arbitrary types are not supported in the error handler "
                f"signature. Please annotate the error with either a "
                f"specific Exception type or a union of Exception types. "
                "For example, `def custom_handler(e: ValueError)` or "
                "`def custom_handler(e: Union[ValueError, TypeError])`. "
                f"Got '{exception_type}' instead."
            )
            raise ValueError(msg)

    # If no type information is available, return (Exception,)
    # for backwards compatibility.
    return (Exception,)


def _filter_validation_errors(
    validation_error: ValidationError,
    injected_args: _InjectedArgs | None,
) -> list[ErrorDetails]:
    """Keep only the validation errors that concern LLM-controlled arguments.

    Errors located on system-injected arguments (state, store, runtime) are
    dropped because the LLM cannot fix them; so are errors with an empty
    location. For the errors that remain, injected argument values are also
    removed from a dict-valued `input` entry, so that only LLM-provided
    arguments show up in error messages.

    Args:
        validation_error: The Pydantic ValidationError raised during tool invocation.
        injected_args: The _InjectedArgs structure containing all injected arguments,
            or None if there are no injected arguments.

    Returns:
        List of ErrorDetails containing only errors for LLM-controlled arguments,
        with system-injected argument values removed from the input field.
    """
    # Names of every tool parameter that is filled in by the system.
    hidden_names: set[str] = set()
    if injected_args:
        if injected_args.state:
            hidden_names.update(injected_args.state.keys())
        for single_name in (injected_args.store, injected_args.runtime):
            if single_name:
                hidden_names.add(single_name)

    kept: list[ErrorDetails] = []
    for detail in validation_error.errors():
        # detail["loc"] is a tuple like ('field_name',) or ('field_name', 'nested')
        location = detail["loc"]
        if not location or location[0] in hidden_names:
            continue

        # Work on a shallow copy so the original error dict is left untouched.
        sanitized: dict[str, Any] = dict(detail)
        offending_input = sanitized.get("input")
        if isinstance(offending_input, dict):
            sanitized["input"] = {
                key: value
                for key, value in offending_input.items()
                if key not in hidden_names
            }

        # ErrorDetails is a TypedDict compatible with this structure.
        kept.append(sanitized)  # type: ignore[arg-type]

    return kept


@dataclass
class _InjectedArgs:
    """Internal structure for tracking injected arguments for a tool.

    This data structure is built once during ToolNode initialization by analyzing
    the tool's signature and args schema, then reused during execution for efficient
    injection without repeated reflection.

    The structure maps from tool parameter names to their injection sources, enabling
    the ToolNode to know exactly which arguments need to be injected and where to
    get their values from.

    Attributes:
        state: Mapping from tool parameter names to state field names for injection.
            Keys are tool parameter names, values are either:
            - str: Name of the state field to extract and inject
            - None: Inject the entire state object
            Empty dict if no state injection is needed.
        store: Name of the tool parameter where the store should be injected,
            or None if no store injection is needed.
        runtime: Name of the tool parameter where the runtime should be injected,
            or None if no runtime injection is needed.

    Example:
        For a tool with signature:
        ```python
        def my_tool(
            x: int,
            messages: Annotated[list, InjectedState("messages")],
            full_state: Annotated[dict, InjectedState()],
            store: Annotated[BaseStore, InjectedStore()],
            runtime: ToolRuntime,
        ) -> str:
            ...
        ```

        The resulting `_InjectedArgs` would be:
        ```python
        _InjectedArgs(
            state={
                "messages": "messages",  # Extract state["messages"]
                "full_state": None,      # Inject entire state
            },
            store="store",               # Inject into "store" parameter
            runtime="runtime",           # Inject into "runtime" parameter
        )
        ```
    """

    # Tool parameter name -> state field name (None means the whole state).
    state: dict[str, str | None]
    # Tool parameter that receives the store, if any.
    store: str | None
    # Tool parameter that receives the runtime, if any.
    runtime: str | None


class ToolNode(RunnableCallable):
    """A node for executing tools in LangGraph workflows.

    Handles tool execution patterns including function calls, state injection,
    persistent storage, and control flow. Manages parallel execution and
    error handling.

    Input Formats:
        1. Graph state with `messages` key that has a list of messages:
            - Common representation for agentic workflows
            - Supports custom messages key via `messages_key` parameter

        2. **Message List**: `[AIMessage(..., tool_calls=[...])]`
            - List of messages with tool calls in the last AIMessage

        3. **Direct Tool Calls**: `[{"name": "tool", "args": {...}, "id": "1", "type": "tool_call"}]`
            - Bypasses message parsing for direct tool execution
            - For programmatic tool invocation and testing

    Output Formats:
        Output format depends on input type and tool behavior:

        **For Regular tools**:

        - Dict input → `{"messages": [ToolMessage(...)]}`
        - List input → `[ToolMessage(...)]`

        **For Command tools**:

        - Returns `[Command(...)]` or mixed list with regular tool outputs
        - `Command` can update state, trigger navigation, or send messages

    Args:
        tools: A sequence of tools that can be invoked by this node.

            Supports:

            - **BaseTool instances**: Tools with schemas and metadata
            - **Plain functions**: Automatically converted to tools with inferred schemas

        name: The name identifier for this node in the graph. Used for debugging
            and visualization.
        tags: Optional metadata tags to associate with the node for filtering
            and organization.
        handle_tool_errors: Configuration for error handling during tool execution.
            Supports multiple strategies:

            - `True`: Catch all errors and return a `ToolMessage` with the default
                error template containing the exception details.
            - `str`: Catch all errors and return a `ToolMessage` with this custom
                error message string.
            - `type[Exception]`: Only catch exceptions with the specified type and
                return the default error message for it.
            - `tuple[type[Exception], ...]`: Only catch exceptions with the specified
                types and return default error messages for them.
            - `Callable[..., str]`: Catch exceptions matching the callable's signature
                and return the string result of calling it with the exception.
            - `False`: Disable error handling entirely, allowing exceptions to
                propagate.

            Defaults to a callable that:

            - Catches tool invocation errors (due to invalid arguments provided by the
                model) and returns a descriptive error message
            - Ignores tool execution errors (they will be re-raised)

        messages_key: The key in the state dictionary that contains the message list.
            This same key will be used for the output `ToolMessage` objects.

            Allows custom state schemas with different message field names.

    Examples:
        Basic usage:

        ```python
        from langchain.tools import ToolNode
        from langchain_core.tools import tool

        @tool
        def calculator(a: int, b: int) -> int:
            \"\"\"Add two numbers.\"\"\"
            return a + b

        tool_node = ToolNode([calculator])
        ```

        State injection:

        ```python
        from typing_extensions import Annotated
        from langchain.tools import InjectedState

        @tool
        def context_tool(query: str, state: Annotated[dict, InjectedState]) -> str:
            \"\"\"Some tool that uses state.\"\"\"
            return f"Query: {query}, Messages: {len(state['messages'])}"

        tool_node = ToolNode([context_tool])
        ```

        Error handling:

        ```python
        def handle_errors(e: ValueError) -> str:
            return "Invalid input provided"


        tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
        ```
    """  # noqa: E501

    name: str = "tools"

    def __init__(
        self,
        tools: Sequence[BaseTool | Callable],
        *,
        name: str = "tools",
        tags: list[str] | None = None,
        handle_tool_errors: bool
        | str
        | Callable[..., str]
        | type[Exception]
        | tuple[type[Exception], ...] = _default_handle_tool_errors,
        messages_key: str = "messages",
        wrap_tool_call: ToolCallWrapper | None = None,
        awrap_tool_call: AsyncToolCallWrapper | None = None,
    ) -> None:
        """Set up the node: store configuration and register every tool.

        Args:
            tools: Sequence of tools to make available for execution. Entries
                that are not `BaseTool` instances are converted into tools.
            name: Node name for graph identification.
            tags: Optional metadata tags.
            handle_tool_errors: Error handling configuration.
            messages_key: State key containing messages.
            wrap_tool_call: Sync wrapper function to intercept tool execution. Receives
                ToolCallRequest and execute callable, returns ToolMessage or Command.
                Enables retries, caching, request modification, and control flow.
            awrap_tool_call: Async wrapper function to intercept tool execution.
                If not provided, falls back to wrap_tool_call for async execution.
        """
        super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)

        # Plain configuration, stored as given.
        self._messages_key = messages_key
        self._handle_tool_errors = handle_tool_errors
        self._wrap_tool_call = wrap_tool_call
        self._awrap_tool_call = awrap_tool_call

        # Per-tool registries, both keyed by tool name.
        self._tools_by_name: dict[str, BaseTool] = {}
        self._injected_args: dict[str, _InjectedArgs] = {}
        for candidate in tools:
            registered = (
                candidate
                if isinstance(candidate, BaseTool)
                else create_tool(cast("type[BaseTool]", candidate))
            )
            self._tools_by_name[registered.name] = registered
            # Work out the injected arguments once, here, so execution does not
            # have to repeat the inspection for every call.
            self._injected_args[registered.name] = _get_all_injected_args(registered)

    @property
    def tools_by_name(self) -> dict[str, BaseTool]:
        """Mapping from tool name to BaseTool instance.

        This is the node's internal registry itself, not a copy.
        """
        return self._tools_by_name

    def _func(
        self,
        input: list[AnyMessage] | dict[str, Any] | BaseModel,
        config: RunnableConfig,
        runtime: Runtime,
    ) -> Any:
        """Run every tool call found in `input` synchronously.

        Tool calls are parsed from the input, each one gets its own config and
        `ToolRuntime`, and they are dispatched through the executor obtained
        from `config`. The individual results are then merged by
        `_combine_tool_outputs`.

        Args:
            input: Message list, state dict, or state model to read tool calls from.
            config: Runnable configuration for this node invocation.
            runtime: Graph runtime supplying context, store, and stream writer.

        Returns:
            Whatever `_combine_tool_outputs` produces for the collected results.
        """
        tool_calls, input_type = self._parse_input(input)
        per_call_configs = get_config_list(config, len(tool_calls))

        # One ToolRuntime per tool call, built up front at the top level.
        tool_runtimes = [
            ToolRuntime(
                state=self._extract_state(input),
                tool_call_id=call["id"],
                config=cfg,
                context=runtime.context,
                store=runtime.store,
                stream_writer=runtime.stream_writer,
            )
            for call, cfg in zip(tool_calls, per_call_configs, strict=False)
        ]

        # The tool calls are handed over as received; argument injection is not
        # applied at this point.
        with get_executor_for_config(config) as executor:
            results = executor.map(
                self._run_one,
                tool_calls,
                [input_type] * len(tool_calls),
                tool_runtimes,
            )
            outputs = list(results)

        return self._combine_tool_outputs(outputs, input_type)

    async def _afunc(
        self,
        input: list[AnyMessage] | dict[str, Any] | BaseModel,
        config: RunnableConfig,
        runtime: Runtime,
    ) -> Any:
        """Run every tool call found in `input` asynchronously.

        Tool calls are parsed from the input, each one gets its own config and
        `ToolRuntime`, and all of them are awaited together with
        `asyncio.gather`. The individual results are then merged by
        `_combine_tool_outputs`.

        Args:
            input: Message list, state dict, or state model to read tool calls from.
            config: Runnable configuration for this node invocation.
            runtime: Graph runtime supplying context, store, and stream writer.

        Returns:
            Whatever `_combine_tool_outputs` produces for the collected results.
        """
        tool_calls, input_type = self._parse_input(input)
        per_call_configs = get_config_list(config, len(tool_calls))

        # One ToolRuntime per tool call, built up front at the top level.
        tool_runtimes = [
            ToolRuntime(
                state=self._extract_state(input),
                tool_call_id=call["id"],
                config=cfg,
                context=runtime.context,
                store=runtime.store,
                stream_writer=runtime.stream_writer,
            )
            for call, cfg in zip(tool_calls, per_call_configs, strict=False)
        ]

        # The tool calls are handed over as received; argument injection is not
        # applied at this point.
        pending = [
            self._arun_one(call, input_type, tool_runtime)  # type: ignore[arg-type]
            for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False)
        ]
        outputs = await asyncio.gather(*pending)

        return self._combine_tool_outputs(outputs, input_type)

    def _combine_tool_outputs(
        self,
        outputs: list[ToolMessage | Command],
        input_type: Literal["list", "dict", "tool_calls"],
    ) -> list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]]:
        """Merge per-tool results into the value returned by the node.

        Without any `Command` among the outputs, the results are returned as one
        unit: the list itself for "list" input, otherwise a dict under the
        messages key. With at least one `Command`, a list of updates is built in
        which each plain output is wrapped individually and commands are passed
        through — except that parent-graph commands whose `goto` is a list of
        `Send` objects are folded into a single parent command appended last.

        Args:
            outputs: Results of the individual tool calls.
            input_type: Shape of the original input, which selects the output shape.

        Returns:
            The combined outputs as described above.
        """
        as_list = input_type == "list"

        # No commands at all: keep the long-standing output shape for backwards
        # compatibility. TypedDict, pydantic, dataclass, etc. should all be able
        # to load from a dict.
        if all(not isinstance(item, Command) for item in outputs):
            return outputs if as_list else {self._messages_key: outputs}

        # LangGraph handles a list mixing Command and non-command node updates.
        merged: list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]] = []
        # Accumulates every parent command with Send targets into one command.
        fan_out: Command | None = None

        for item in outputs:
            if not isinstance(item, Command):
                merged.append([item] if as_list else {self._messages_key: [item]})
                continue

            targets_parent_sends = (
                item.graph is Command.PARENT
                and isinstance(item.goto, list)
                and all(isinstance(target, Send) for target in item.goto)
            )
            if not targets_parent_sends:
                merged.append(item)
            elif fan_out:
                fan_out = replace(
                    fan_out,
                    goto=cast("list[Send]", fan_out.goto) + item.goto,
                )
            else:
                fan_out = Command(graph=Command.PARENT, goto=item.goto)

        if fan_out:
            merged.append(fan_out)
        return merged

    def _execute_tool_sync(
        self,
        request: ToolCallRequest,
        input_type: Literal["list", "dict", "tool_calls"],
        config: RunnableConfig,
    ) -> ToolMessage | Command:
        """Execute tool call with configured error handling.

        Args:
            request: Tool execution request.
            input_type: Input format.
            config: Runnable configuration.

        Returns:
            ToolMessage or Command.

        Raises:
            Exception: If tool fails and handle_tool_errors is False.
        """
        call = request.tool_call
        tool = request.tool

        # Validate tool exists when we actually need to execute it
        if tool is None:
            if invalid_tool_message := self._validate_tool_call(call):
                return invalid_tool_message
            # This should never happen if validation works correctly
            msg = f"Tool {call['name']} is not registered with ToolNode"
            raise TypeError(msg)

        # Inject state, store, and runtime right before invocation
        injected_call = self._inject_tool_args(call, request.runtime)
        call_args = {**injected_call, "type": "tool_call"}

        try:
            try:
                response = tool.invoke(call_args, config)
            except ValidationError as exc:
                # Filter out errors for injected arguments
                injected = self._injected_args.get(call["name"])
                filtered_errors = _filter_validation_errors(exc, injected)
                # Use original call["args"] without injected values for error reporting
                raise ToolInvocationError(
                    call["name"], exc, call["args"], filtered_errors
                ) from exc

        # GraphInterrupt is a special exception that will always be raised.
        # It can be triggered in the following scenarios,
        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
        # most commonly:
        # (1) a GraphInterrupt is raised inside a tool
        # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
        #     called as a tool
        # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
        except GraphBubbleUp:
            raise
        except Exception as e:
            # Determine which exception types are handled
            handled_types: tuple[type[Exception], ...]
            if isinstance(self._handle_tool_errors, type) and issubclass(
                self._handle_tool_errors, Exception
            ):
                handled_types = (self._handle_tool_errors,)
            elif isinstance(self._handle_tool_errors, tuple):
                handled_types = self._handle_tool_errors
            elif callable(self._handle_tool_errors) and not isinstance(
                self._handle_tool_errors, type
            ):
                handled_types = _infer_handled_types(self._handle_tool_errors)
            else:
                # default behavior is catching all exceptions
                handled_types = (Exception,)

            # Check if this error should be handled
            if not self._handle_tool_errors or not isinstance(e, handled_types):
                raise

            # Error is handled - create error ToolMessage
            content = _handle_tool_error(e, flag=self._handle_tool_errors)
            return ToolMessage(
                content=content,
                name=call["name"],
                tool_call_id=call["id"],
                status="error",
            )

        # Process successful response
        if isinstance(response, Command):
            # Validate Command before returning to handler
            return self._validate_tool_command(response, request.tool_call, input_type)
        if isinstance(response, ToolMessage):
            response.content = cast("str | list", msg_content_output(response.content))
            return response

        msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
        raise TypeError(msg)

    def _run_one(
        self,
        call: ToolCall,
        input_type: Literal["list", "dict", "tool_calls"],
        tool_runtime: ToolRuntime,
    ) -> ToolMessage | Command:
        """Execute single tool call with wrap_tool_call wrapper if configured.

        Args:
            call: Tool call dict.
            input_type: Input format.
            tool_runtime: Tool runtime.

        Returns:
            ToolMessage or Command. When the wrapper raises and
            `handle_tool_errors` is enabled, an error `ToolMessage` is returned.

        Raises:
            GraphBubbleUp: Always propagated, whether raised by the tool or by
                the wrapper, so interrupts are never turned into error messages.
            Exception: Wrapper failures when `handle_tool_errors` is disabled.
        """
        # Validation is deferred to _execute_tool_sync to allow interceptors
        # to short-circuit requests for unregistered tools
        tool = self.tools_by_name.get(call["name"])

        # Create the tool request with state and runtime
        tool_request = ToolCallRequest(
            tool_call=call,
            tool=tool,
            state=tool_runtime.state,
            runtime=tool_runtime,
        )

        config = tool_runtime.config

        if self._wrap_tool_call is None:
            # No wrapper - execute directly
            return self._execute_tool_sync(tool_request, input_type, config)

        # Define execute callable that can be called multiple times
        def execute(req: ToolCallRequest) -> ToolMessage | Command:
            """Execute tool with given request. Can be called multiple times."""
            return self._execute_tool_sync(req, input_type, config)

        # Call wrapper with request and execute callable
        try:
            return self._wrap_tool_call(tool_request, execute)
        except GraphBubbleUp:
            # Same rule as in _execute_tool_sync: control-flow signals such as
            # GraphInterrupt must reach the graph runtime. Without this clause
            # the generic handler below would convert an interrupt that passed
            # through the wrapper into an error ToolMessage.
            raise
        except Exception as e:
            # Wrapper threw an exception
            if not self._handle_tool_errors:
                raise
            # Convert to error message
            content = _handle_tool_error(e, flag=self._handle_tool_errors)
            return ToolMessage(
                content=content,
                name=tool_request.tool_call["name"],
                tool_call_id=tool_request.tool_call["id"],
                status="error",
            )

    async def _execute_tool_async(
        self,
        request: ToolCallRequest,
        input_type: Literal["list", "dict", "tool_calls"],
        config: RunnableConfig,
    ) -> ToolMessage | Command:
        """Execute tool call asynchronously with configured error handling.

        Args:
            request: Tool execution request.
            input_type: Input format.
            config: Runnable configuration.

        Returns:
            ToolMessage or Command.

        Raises:
            Exception: If tool fails and handle_tool_errors is False.
        """
        call = request.tool_call
        tool = request.tool

        # Validate tool exists when we actually need to execute it
        if tool is None:
            if invalid_tool_message := self._validate_tool_call(call):
                return invalid_tool_message
            # This should never happen if validation works correctly
            msg = f"Tool {call['name']} is not registered with ToolNode"
            raise TypeError(msg)

        # Inject state, store, and runtime right before invocation
        injected_call = self._inject_tool_args(call, request.runtime)
        call_args = {**injected_call, "type": "tool_call"}

        try:
            try:
                response = await tool.ainvoke(call_args, config)
            except ValidationError as exc:
                # Filter out errors for injected arguments
                injected = self._injected_args.get(call["name"])
                filtered_errors = _filter_validation_errors(exc, injected)
                # Use original call["args"] without injected values for error reporting
                raise ToolInvocationError(
                    call["name"], exc, call["args"], filtered_errors
                ) from exc

        # GraphInterrupt is a special exception that will always be raised.
        # It can be triggered in the following scenarios,
        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
        # most commonly:
        # (1) a GraphInterrupt is raised inside a tool
        # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
        #     called as a tool
        # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
        except GraphBubbleUp:
            raise
        except Exception as e:
            # Determine which exception types are handled
            handled_types: tuple[type[Exception], ...]
            if isinstance(self._handle_tool_errors, type) and issubclass(
                self._handle_tool_errors, Exception
            ):
                handled_types = (self._handle_tool_errors,)
            elif isinstance(self._handle_tool_errors, tuple):
                handled_types = self._handle_tool_errors
            elif callable(self._handle_tool_errors) and not isinstance(
                self._handle_tool_errors, type
            ):
                handled_types = _infer_handled_types(self._handle_tool_errors)
            else:
                # default behavior is catching all exceptions
                handled_types = (Exception,)

            # Check if this error should be handled
            if not self._handle_tool_errors or not isinstance(e, handled_types):
                raise

            # Error is handled - create error ToolMessage
            content = _handle_tool_error(e, flag=self._handle_tool_errors)
            return ToolMessage(
                content=content,
                name=call["name"],
                tool_call_id=call["id"],
                status="error",
            )

        # Process successful response
        if isinstance(response, Command):
            # Validate Command before returning to handler
            return self._validate_tool_command(response, request.tool_call, input_type)
        if isinstance(response, ToolMessage):
            response.content = cast("str | list", msg_content_output(response.content))
            return response

        msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
        raise TypeError(msg)

    async def _arun_one(
        self,
        call: ToolCall,
        input_type: Literal["list", "dict", "tool_calls"],
        tool_runtime: ToolRuntime,
    ) -> ToolMessage | Command:
        """Execute single tool call asynchronously with awrap_tool_call wrapper if configured.

        The async wrapper is preferred. If only a sync wrapper is configured it is
        called directly with a synchronous execute callable.

        Args:
            call: Tool call dict.
            input_type: Input format.
            tool_runtime: Tool runtime.

        Returns:
            ToolMessage or Command. When the wrapper raises and
            `handle_tool_errors` is enabled, an error `ToolMessage` is returned.

        Raises:
            GraphBubbleUp: Always propagated, whether raised by the tool or by
                the wrapper, so interrupts are never turned into error messages.
            Exception: Wrapper failures when `handle_tool_errors` is disabled.
        """
        # Validation is deferred to _execute_tool_async to allow interceptors
        # to short-circuit requests for unregistered tools
        tool = self.tools_by_name.get(call["name"])

        # Create the tool request with state and runtime
        tool_request = ToolCallRequest(
            tool_call=call,
            tool=tool,
            state=tool_runtime.state,
            runtime=tool_runtime,
        )

        config = tool_runtime.config

        if self._awrap_tool_call is None and self._wrap_tool_call is None:
            # No wrapper - execute directly
            return await self._execute_tool_async(tool_request, input_type, config)

        # Define async execute callable that can be called multiple times
        async def execute(req: ToolCallRequest) -> ToolMessage | Command:
            """Execute tool with given request. Can be called multiple times."""
            return await self._execute_tool_async(req, input_type, config)

        def _sync_execute(req: ToolCallRequest) -> ToolMessage | Command:
            """Sync execute fallback for sync wrapper."""
            return self._execute_tool_sync(req, input_type, config)

        # Call wrapper with request and execute callable
        try:
            if self._awrap_tool_call is not None:
                return await self._awrap_tool_call(tool_request, execute)
            # Only the sync wrapper is configured here (the both-None case returned
            # above). Narrow the type through a local instead of reassigning the
            # instance attribute.
            sync_wrapper = cast("ToolCallWrapper", self._wrap_tool_call)
            return sync_wrapper(tool_request, _sync_execute)
        except GraphBubbleUp:
            # Same rule as in _execute_tool_async: control-flow signals such as
            # GraphInterrupt must reach the graph runtime. Without this clause
            # the generic handler below would convert an interrupt that passed
            # through the wrapper into an error ToolMessage.
            raise
        except Exception as e:
            # Wrapper threw an exception
            if not self._handle_tool_errors:
                raise
            # Convert to error message
            content = _handle_tool_error(e, flag=self._handle_tool_errors)
            return ToolMessage(
                content=content,
                name=tool_request.tool_call["name"],
                tool_call_id=tool_request.tool_call["id"],
                status="error",
            )

    def _parse_input(
        self,
        input: list[AnyMessage] | dict[str, Any] | BaseModel,
    ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
        input_type: Literal["list", "dict", "tool_calls"]
        if isinstance(input, list):
            if isinstance(input[-1], dict) and input[-1].get("type") == "tool_call":
                input_type = "tool_calls"
                tool_calls = cast("list[ToolCall]", input)
                return tool_calls, input_type
            input_type = "list"
            messages = input
        elif (
            isinstance(input, dict) and input.get("__type") == "tool_call_with_context"
        ):
            # Handle ToolCallWithContext from Send API
            # mypy will not be able to type narrow correctly since the signature
            # for input contains dict[str, Any]. We'd need to narrow dict[str, Any]
            # before we can apply correct typing.
            input_with_ctx = cast("ToolCallWithContext", input)
            input_type = "tool_calls"
            return [input_with_ctx["tool_call"]], input_type
        elif isinstance(input, dict) and (
            messages := input.get(self._messages_key, [])
        ):
            input_type = "dict"
        elif messages := getattr(input, self._messages_key, []):
            # Assume dataclass-like state that can coerce from dict
            input_type = "dict"
        else:
            msg = "No message found in input"
            raise ValueError(msg)

        try:
            latest_ai_message = next(
                m for m in reversed(messages) if isinstance(m, AIMessage)
            )
        except StopIteration:
            msg = "No AIMessage found in input"
            raise ValueError(msg)

        tool_calls = list(latest_ai_message.tool_calls)
        return tool_calls, input_type

    def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None:
        requested_tool = call["name"]
        if requested_tool not in self.tools_by_name:
            all_tool_names = list(self.tools_by_name.keys())
            content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
                requested_tool=requested_tool,
                available_tools=", ".join(all_tool_names),
            )
            return ToolMessage(
                content, name=requested_tool, tool_call_id=call["id"], status="error"
            )
        return None

    def _extract_state(
        self, input: list[AnyMessage] | dict[str, Any] | BaseModel
    ) -> list[AnyMessage] | dict[str, Any] | BaseModel:
        """Extract state from input, handling ToolCallWithContext if present.

        Args:
            input: The input which may be raw state or ToolCallWithContext.

        Returns:
            The actual state to pass to wrap_tool_call wrappers.
        """
        if isinstance(input, dict) and input.get("__type") == "tool_call_with_context":
            return input["state"]
        return input

    def _inject_tool_args(
        self,
        tool_call: ToolCall,
        tool_runtime: ToolRuntime,
    ) -> ToolCall:
        """Inject graph state, store, and runtime into tool call arguments.

        This is an internal method that enables tools to access graph context that
        should not be controlled by the model. Tools can declare dependencies on graph
        state, persistent storage, or runtime context using InjectedState, InjectedStore,
        and ToolRuntime annotations. This method automatically identifies these
        dependencies and injects the appropriate values.

        The injection process preserves the original tool call structure while adding
        the necessary context arguments. This allows tools to be both model-callable
        and context-aware without exposing internal state management to the model.

        Args:
            tool_call: The tool call dictionary to augment with injected arguments.
                Must contain 'name', 'args', 'id', and 'type' fields.
            tool_runtime: The ToolRuntime instance containing all runtime context
                (state, config, store, context, stream_writer) to inject into tools.

        Returns:
            A new ToolCall dictionary with the same structure as the input but with
            additional arguments injected based on the tool's annotation requirements.

        Raises:
            ValueError: If a tool requires store injection but no store is provided,
                or if state injection requirements cannot be satisfied.

        !!! note
            This method is called automatically during tool execution. It should not
            be called from outside the `ToolNode`.
        """
        if tool_call["name"] not in self.tools_by_name:
            return tool_call

        injected = self._injected_args.get(tool_call["name"])
        if not injected:
            return tool_call

        tool_call_copy: ToolCall = copy(tool_call)
        injected_args = {}

        # Inject state
        if injected.state:
            state = tool_runtime.state
            # Handle list state by converting to dict
            if isinstance(state, list):
                required_fields = list(injected.state.values())
                if (
                    len(required_fields) == 1
                    and required_fields[0] == self._messages_key
                ) or required_fields[0] is None:
                    state = {self._messages_key: state}
                else:
                    err_msg = (
                        f"Invalid input to ToolNode. Tool {tool_call['name']} requires "
                        f"graph state dict as input."
                    )
                    if any(state_field for state_field in injected.state.values()):
                        required_fields_str = ", ".join(f for f in required_fields if f)
                        err_msg += (
                            f" State should contain fields {required_fields_str}."
                        )
                    raise ValueError(err_msg)

            # Extract state values
            if isinstance(state, dict):
                for tool_arg, state_field in injected.state.items():
                    injected_args[tool_arg] = (
                        state[state_field] if state_field else state
                    )
            else:
                for tool_arg, state_field in injected.state.items():
                    injected_args[tool_arg] = (
                        getattr(state, state_field) if state_field else state
                    )

        # Inject store
        if injected.store:
            if tool_runtime.store is None:
                msg = (
                    "Cannot inject store into tools with InjectedStore annotations - "
                    "please compile your graph with a store."
                )
                raise ValueError(msg)
            injected_args[injected.store] = tool_runtime.store

        # Inject runtime
        if injected.runtime:
            injected_args[injected.runtime] = tool_runtime

        tool_call_copy["args"] = {**tool_call_copy["args"], **injected_args}
        return tool_call_copy

    def _validate_tool_command(
        self,
        command: Command,
        call: ToolCall,
        input_type: Literal["list", "dict", "tool_calls"],
    ) -> Command:
        if isinstance(command.update, dict):
            # input type is dict when ToolNode is invoked with a dict input
            # (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
            if input_type not in ("dict", "tool_calls"):
                msg = (
                    "Tools can provide a dict in Command.update only when using dict "
                    f"with '{self._messages_key}' key as ToolNode input, "
                    f"got: {command.update} for tool '{call['name']}'"
                )
                raise ValueError(msg)

            updated_command = deepcopy(command)
            state_update = cast("dict[str, Any]", updated_command.update) or {}
            messages_update = state_update.get(self._messages_key, [])
        elif isinstance(command.update, list):
            # Input type is list when ToolNode is invoked with a list input
            # (e.g. [AIMessage(..., tool_calls=[...])])
            if input_type != "list":
                msg = (
                    "Tools can provide a list of messages in Command.update "
                    "only when using list of messages as ToolNode input, "
                    f"got: {command.update} for tool '{call['name']}'"
                )
                raise ValueError(msg)

            updated_command = deepcopy(command)
            messages_update = updated_command.update
        else:
            return command

        # convert to message objects if updates are in a dict format
        messages_update = convert_to_messages(messages_update)

        # no validation needed if all messages are being removed
        if messages_update == [RemoveMessage(id=REMOVE_ALL_MESSAGES)]:
            return updated_command

        has_matching_tool_message = False
        for message in messages_update:
            if not isinstance(message, ToolMessage):
                continue

            if message.tool_call_id == call["id"]:
                message.name = call["name"]
                has_matching_tool_message = True

        # validate that we always have a ToolMessage matching the tool call in
        # Command.update if command is sent to the CURRENT graph
        if updated_command.graph is None and not has_matching_tool_message:
            example_update = (
                '`Command(update={"messages": '
                '[ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
                if input_type == "dict"
                else "`Command(update="
                '[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
            )
            msg = (
                "Expected to have a matching ToolMessage in Command.update "
                f"for tool '{call['name']}', got: {messages_update}. "
                "Every tool call (LLM requesting to call a tool) "
                "in the message history MUST have a corresponding ToolMessage. "
                f"You can fix it by modifying the tool to return {example_update}."
            )
            raise ValueError(msg)
        return updated_command


def tools_condition(
    state: list[AnyMessage] | dict[str, Any] | BaseModel,
    messages_key: str = "messages",
) -> Literal["tools", "__end__"]:
    """Conditional routing function for tool-calling workflows.

    This utility function implements the standard conditional logic for ReAct-style
    agents: if the last message contains tool calls, route to the tool execution
    node; otherwise, end the workflow.

    Args:
        state: The current graph state to examine for tool calls. Supported formats:
            - List of messages
            - Dictionary containing a messages key (for `StateGraph`)
            - Object (e.g. `BaseModel` instance) with a messages attribute
        messages_key: The key or attribute name containing the message list in the state.
            This allows customization for graphs using different state schemas.

    Returns:
        Either `'tools'` if the last message has a non-empty `tool_calls`
            attribute, or `'__end__'` to terminate the workflow.

    Raises:
        ValueError: If no messages can be found in the provided state, including
            when the state is an empty list or holds an empty message list.

    Example:
        Basic usage in a ReAct agent:

        ```python
        from langgraph.graph import StateGraph
        from langchain.tools import ToolNode
        from langchain.tools.tool_node import tools_condition
        from typing_extensions import TypedDict


        class State(TypedDict):
            messages: list


        graph = StateGraph(State)
        graph.add_node("llm", call_model)
        graph.add_node("tools", ToolNode([my_tool]))
        graph.add_conditional_edges(
            "llm",
            tools_condition,  # Routes to "tools" or "__end__"
            {"tools": "tools", "__end__": "__end__"},
        )
        ```

        Custom messages key:

        ```python
        def custom_condition(state):
            return tools_condition(state, messages_key="chat_history")
        ```

    !!! note
        Only the last message is inspected. It is expected to be an `AIMessage`
        when tool calls are present, which is the standard output format for
        tool-calling language models.
    """
    if isinstance(state, list):
        messages = state
    elif isinstance(state, dict) and (found := state.get(messages_key, [])):
        messages = found
    else:
        # Attribute-style state; also the fallback for dicts without messages.
        messages = getattr(state, messages_key, [])

    # One check covers every format, so an empty list raises the documented
    # ValueError instead of an IndexError from indexing [-1].
    if not messages:
        msg = f"No messages found in input state to tool_edge: {state}"
        raise ValueError(msg)

    ai_message = messages[-1]
    # A missing, None or empty `tool_calls` all mean "nothing to execute".
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return "__end__"


@dataclass
class ToolRuntime(_DirectlyInjectedToolArg, Generic[ContextT, StateT]):
    """Runtime context automatically injected into tools.

    When a tool function declares a `ToolRuntime` parameter (for example
    `runtime: ToolRuntime`), the tool execution system will automatically inject
    an instance containing:

    - `state`: The current graph state
    - `tool_call_id`: The ID of the current tool call
    - `config`: `RunnableConfig` for the current execution
    - `context`: Runtime context (from langgraph `Runtime`)
    - `store`: `BaseStore` instance for persistent storage (from langgraph `Runtime`)
    - `stream_writer`: `StreamWriter` for streaming output (from langgraph `Runtime`)

    No `Annotated` wrapper is needed - just use `runtime: ToolRuntime`
    as a parameter.

    Example:
        ```python
        from langchain_core.tools import tool
        from langchain.tools import ToolRuntime

        @tool
        def my_tool(x: int, runtime: ToolRuntime) -> str:
            \"\"\"Tool that accesses runtime context.\"\"\"
            # Access state
            messages = runtime.state["messages"]

            # Access tool_call_id
            print(f"Tool call ID: {runtime.tool_call_id}")

            # Access config
            print(f"Run ID: {runtime.config.get('run_id')}")

            # Access runtime context
            user_id = runtime.context.get("user_id")

            # Access store (may be `None` when no store is configured)
            runtime.store.put(("metrics",), "count", 1)

            # Stream output (`StreamWriter` is a callable)
            runtime.stream_writer("Processing...")

            return f"Processed {x}"
        ```

    !!! note
        This is a marker class used for type checking and detection.
        The actual runtime object will be constructed during tool execution.
    """

    state: StateT
    context: ContextT
    config: RunnableConfig
    stream_writer: StreamWriter
    tool_call_id: str | None
    store: BaseStore | None


class InjectedState(InjectedToolArg):
    """Annotation for injecting graph state into tool arguments.

    This annotation enables tools to access graph state without exposing state
    management details to the language model. Tools annotated with `InjectedState`
    receive state data automatically during execution while remaining invisible
    to the model's tool-calling interface.

    Args:
        field: Optional key to extract from the state dictionary. If `None`, the entire
            state is injected. If specified, only that field's value is injected.
            This allows tools to request specific state components rather than
            processing the full state structure.

    Example:
        ```python
        from typing import List
        from typing_extensions import Annotated, TypedDict

        from langchain_core.messages import BaseMessage, AIMessage
        from langchain.tools import InjectedState, ToolNode, tool


        class AgentState(TypedDict):
            messages: List[BaseMessage]
            foo: str


        @tool
        def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
            '''Do something with state.'''
            if len(state["messages"]) > 2:
                return state["foo"] + str(x)
            else:
                return "not enough messages"


        @tool
        def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
            '''Do something else with state.'''
            return foo + str(x + 1)


        node = ToolNode([state_tool, foo_tool])

        tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
        tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"}
        state = {
            "messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])],
            "foo": "bar",
        }
        node.invoke(state)
        ```

        Because the node was invoked with a dict, the resulting messages are
        returned under the messages key:

        ```python
        {
            "messages": [
                ToolMessage(content="not enough messages", name="state_tool", tool_call_id="1"),
                ToolMessage(content="bar2", name="foo_tool", tool_call_id="2"),
            ]
        }
        ```

    !!! note
        - `InjectedState` arguments are automatically excluded from tool schemas
            presented to language models
        - `ToolNode` handles the injection process during execution
        - Tools can mix regular arguments (controlled by the model) with injected
            arguments (controlled by the system)
        - State injection occurs after the model generates tool calls but before
            tool execution
    """

    def __init__(self, field: str | None = None) -> None:
        """Initialize the `InjectedState` annotation.

        Args:
            field: Name of the state field to inject, or `None` to inject the
                entire state.
        """
        # Stored as-is on the annotation instance.
        self.field = field


class InjectedStore(InjectedToolArg):
    """Annotation for injecting persistent store into tool arguments.

    This annotation enables tools to access LangGraph's persistent storage system
    without exposing storage details to the language model. Tools annotated with
    `InjectedStore` receive the store instance automatically during execution while
    remaining invisible to the model's tool-calling interface.

    The store provides persistent, cross-session data storage that tools can use
    for maintaining context, user preferences, or any other data that needs to
    persist beyond individual workflow executions.

    !!! warning
        `InjectedStore` annotation requires `langchain-core >= 0.3.8`

    Example:
        ```python
        from typing import Any

        from typing_extensions import Annotated
        from langchain.tools import InjectedStore, ToolNode, tool

        @tool
        def save_preference(
            key: str,
            value: str,
            store: Annotated[Any, InjectedStore()]
        ) -> str:
            \"\"\"Save user preference to persistent storage.\"\"\"
            store.put(("preferences",), key, value)
            return f"Saved {key} = {value}"

        @tool
        def get_preference(
            key: str,
            store: Annotated[Any, InjectedStore()]
        ) -> str:
            \"\"\"Retrieve user preference from persistent storage.\"\"\"
            result = store.get(("preferences",), key)
            return result.value if result else "Not found"
        ```

        Usage with `ToolNode` and graph compilation:

        ```python
        from langgraph.graph import StateGraph
        from langgraph.store.memory import InMemoryStore

        store = InMemoryStore()
        tool_node = ToolNode([save_preference, get_preference])

        # `State` stands for your graph's state schema (not defined in this snippet)
        graph = StateGraph(State)
        graph.add_node("tools", tool_node)
        compiled_graph = graph.compile(store=store)  # Store is injected automatically
        ```

        Cross-session persistence:

        ```python
        from langchain_core.messages import HumanMessage

        # First session
        result1 = compiled_graph.invoke(
            {"messages": [HumanMessage("Save my favorite color as blue")]}
        )

        # Later session - data persists
        result2 = compiled_graph.invoke(
            {"messages": [HumanMessage("What's my favorite color?")]}
        )
        ```

    !!! note
        - `InjectedStore` arguments are automatically excluded from tool schemas
            presented to language models
        - The store instance is automatically injected by `ToolNode` during execution
        - Tools can access namespaced storage using the store's get/put methods
        - Store injection requires the graph to be compiled with a store instance
        - Multiple tools can share the same store instance for data consistency
    """


def _is_injection(
    type_arg: Any,
    injection_type: type[InjectedState | InjectedStore | ToolRuntime],
) -> bool:
    """Check if a type argument represents an injection annotation.

    This utility function determines whether a type annotation indicates that
    an argument should be injected with state or store data. It handles both
    direct annotations and nested annotations within Union or Annotated types.

    Args:
        type_arg: The type argument to check for injection annotations.
        injection_type: The injection type to look for (InjectedState or InjectedStore).

    Returns:
        True if the type argument contains the specified injection annotation.
    """
    if isinstance(type_arg, injection_type) or (
        isinstance(type_arg, type) and issubclass(type_arg, injection_type)
    ):
        return True
    origin_ = get_origin(type_arg)
    if origin_ is Union or origin_ is Annotated:
        return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))
    return False


def _get_injection_from_type(
    type_: Any, injection_type: type[InjectedState | InjectedStore | ToolRuntime]
) -> Any | None:
    """Find the injection marker carried by a type annotation, if any.

    The direct type arguments of `type_` are scanned first; the annotation
    itself is only consulted when none of its arguments is a marker.

    Args:
        type_: The type annotation to inspect.
        injection_type: The injection type to look for.

    Returns:
        The single matching type argument when exactly one is found, `True` when
        no argument matches but `type_` itself is an injection, `None` otherwise.

    Raises:
        ValueError: If more than one type argument matches `injection_type`.
    """
    found: list[Any] = []
    for candidate in get_args(type_):
        if _is_injection(candidate, injection_type):
            found.append(candidate)

    if not found:
        # No marker among the arguments: fall back to the annotation as a whole.
        return True if _is_injection(type_, injection_type) else None

    if len(found) == 1:
        return found[0]

    msg = (
        f"A tool argument should not be annotated with {injection_type.__name__} "
        f"more than once. Found: {found}"
    )
    raise ValueError(msg)


def _get_all_injected_args(tool: BaseTool) -> _InjectedArgs:
    """Collect every injected argument of a tool in one pass.

    Annotations are gathered from the tool's input schema and from the type hints
    of its underlying callable (`func`, falling back to `coroutine`). When both
    sources annotate the same name, the schema annotation is used.

    Args:
        tool: The tool to analyze for injection requirements.

    Returns:
        An `_InjectedArgs` holding the state arguments (argument name mapped to
        the requested state field, or `None` when no field is set), the store
        argument name and the runtime argument name. If several arguments qualify
        as store or runtime, the last one encountered is kept.
    """
    schema_hints = get_all_basemodel_annotations(tool.get_input_schema())

    target = getattr(tool, "func", None) or getattr(tool, "coroutine", None)
    signature_hints = get_type_hints(target, include_extras=True) if target else {}

    # Signature hints go in first so that schema hints overwrite them on conflict.
    merged = dict(signature_hints)
    merged.update(schema_hints)

    state_fields: dict[str, str | None] = {}
    store_name: str | None = None
    runtime_name: str | None = None

    for arg_name, annotation in merged.items():
        state_marker = _get_injection_from_type(annotation, InjectedState)
        store_marker = _get_injection_from_type(annotation, InjectedStore)
        runtime_marker = _get_injection_from_type(annotation, ToolRuntime)

        if state_marker:
            # Only an `InjectedState` instance can carry a field; a bare marker
            # (or an empty field) is recorded as `None`.
            requested = (
                state_marker.field if isinstance(state_marker, InjectedState) else None
            )
            state_fields[arg_name] = requested or None

        if store_marker:
            store_name = arg_name

        # A parameter literally named "runtime" is treated as the runtime
        # argument even without a `ToolRuntime` annotation.
        if arg_name == "runtime" or runtime_marker:
            runtime_name = arg_name

    return _InjectedArgs(state=state_fields, store=store_name, runtime=runtime_name)
